package org.example.nebula.cluster

import com.facebook.thrift.protocol.TCompactProtocol
import com.vesoft.nebula.algorithm.lib.DegreeStaticAlgo
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.example.nebula.basic.ReadData
import org.example.utils.VertexUtil.{convertStringId2LongId, reconvertLongId2StringId}

import scala.collection.mutable.ListBuffer

/**
 * Spark driver that runs Nebula's DegreeStatic algorithm over an edge list
 * read from CSV, mapping string vertex ids to longs for the algorithm and
 * back again for the output.
 *
 * Side effects: writes the degree result to /tmp/result_degree_5002.csv and
 * the per-stage timing log to /tmp/result_degree_5002_time.txt.
 */
object RunDegree {

    /** Converts a pair of System.nanoTime() readings into elapsed milliseconds. */
    private def elapsedMs(startNanos: Long, endNanos: Long): Long =
        (endNanos - startNanos) / 1000000

    def main(args: Array[String]): Unit = {
        // Kryo registration of TCompactProtocol is required by the Nebula connector.
        val sparkConf = new SparkConf()
                .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
                .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol]))
        val spark = SparkSession
                .builder().appName("nebula-degree")
                .master("yarn")
                .config(sparkConf)
                .getOrCreate()

        // Per-stage timing lines, collected here and persisted as a text file at the end.
        val list: ListBuffer[String] = ListBuffer[String]()

        val start = System.nanoTime()
        val stringCsvDF = ReadData.readStringCsvData(spark)
        val readEnd = System.nanoTime()
        // FIX: count() is a Spark action — the original called it twice, re-running
        // the whole read job. Trigger it once and reuse the result.
        val edgeCount = stringCsvDF.count()
        list.append("读取文件耗时" + elapsedMs(start, readEnd) + "ms")
        list.append("共" + edgeCount + "条边")
        // FIX: the original passed two arguments to println, which Scala auto-tuples
        // and prints as "(=====…, 耗时Xms)". Emit a single interpolated string instead.
        println(s"===================================读取文件===================================耗时${elapsedMs(start, readEnd)}ms")
        println("==========================共" + edgeCount + "条边==========================")

        degreeWithIdMaping(spark, stringCsvDF, list)
        val end = System.nanoTime()
        // Total elapsed time is measured from before the initial read.
        println(s"===================================共计===================================耗时${elapsedMs(start, end)}ms")
        list.append("共计耗时" + elapsedMs(start, end) + "ms")

        val result = spark.sparkContext.makeRDD(list)
        // Single partition so the timing log lands in one output file.
        result.repartition(1).saveAsTextFile("/tmp/result_degree_5002_time.txt")

        spark.stop()
    }

    /**
     * Runs the degree computation pipeline, timing each stage:
     * string-id -> long-id encoding, DegreeStatic, long-id -> string-id
     * decoding, and the CSV write of the decoded result.
     *
     * @param spark active session used by the algorithm and the id re-mapping
     * @param df    edge DataFrame with string vertex ids
     * @param list  accumulator for human-readable timing lines (appended in place)
     */
    def degreeWithIdMaping(spark: SparkSession, df: DataFrame, list: ListBuffer[String]): Unit = {
        // start/end are reused across the four timed stages below.
        var start = System.nanoTime()
        val encodedDF = convertStringId2LongId(df, list)
        var end = System.nanoTime()
        list.append("转换id耗时" + elapsedMs(start, end) + "ms")
        println(s"===================================转换id===================================耗时${elapsedMs(start, end)}ms")

        start = System.nanoTime()
        val degree = DegreeStaticAlgo.apply(spark, encodedDF)
        end = System.nanoTime()
        list.append("算法计算" + elapsedMs(start, end) + "ms")
        println(s"===================================算法计算===================================耗时${elapsedMs(start, end)}ms")

        start = System.nanoTime()
        val decodedResult = reconvertLongId2StringId(spark, degree)
        end = System.nanoTime()
        list.append("映射id" + elapsedMs(start, end) + "ms")
        println(s"===================================映射id===================================耗时${elapsedMs(start, end)}ms")

        start = System.nanoTime()
        // Single partition so the result is one CSV file; the write is the action
        // that actually triggers the computation above.
        decodedResult.repartition(1).write.option("header", true).csv("/tmp/result_degree_5002.csv")
        end = System.nanoTime()
        list.append("结果写入文件" + elapsedMs(start, end) + "ms")
        println(s"===================================结果写入文件===================================耗时${elapsedMs(start, end)}ms")
    }

}
